import csv
import os

import numpy as np
from torch.utils.data import Dataset


# Label codes in class-index order: a code's position is its integer label.
LableDecode = ["N", "A", "O", "~"]

# Inverse of LableDecode (code string -> integer label), derived from it so
# the two tables cannot drift apart.
LableEncode = {code: idx for idx, code in enumerate(LableDecode)}


class ECGDataSet(Dataset):
    """Map-style dataset of ECG records stored as ``.npy`` files.

    The directory *dir* must contain a ``REFERENCE.csv`` file. Each of its
    rows is ``record_name,label_code``. A record's samples are loaded from
    ``<dir>/<record_name>.npy``. The label code is mapped to an integer
    through ``LableEncode``.
    """

    def __init__(self, dir: str):
        """Read the record/label index from ``REFERENCE.csv`` in *dir*.

        Args:
            dir: Directory holding ``REFERENCE.csv`` and the ``.npy`` files.
                A trailing path separator is optional.
        """
        self.dir = dir
        ref_path = os.path.join(self.dir, "REFERENCE.csv")
        # newline="" is the mode the csv module requires for its input files.
        with open(ref_path, "r", newline="") as f:
            # csv.reader yields [] for blank lines (e.g. a trailing empty
            # line). Drop them so they neither count towards len() nor break
            # indexing in __getitem__.
            self.reference = [row for row in csv.reader(f) if row]

    def __getitem__(self, index):
        """Return ``(data, label)`` for the *index*-th record.

        ``data`` is the record's array loaded from disk and cast to
        ``float32``. Its shape is whatever the ``.npy`` file stores; the
        caller presumably knows it. ``label`` is the integer from
        ``LableEncode``.

        Raises:
            IndexError: if *index* is out of range.
            KeyError: if the record's label code is not in ``LableEncode``.
        """
        row = self.reference[index]
        record_name, label_code = row[0], row[1]
        npy_path = os.path.join(self.dir, f"{record_name}.npy")
        data = np.load(npy_path).astype(np.float32)
        label = LableEncode[label_code]
        return data, label

    def __len__(self):
        """Return the number of (non-blank) rows in ``REFERENCE.csv``."""
        return len(self.reference)
